library(here)
library(data.table)
library(PanelMatch)
library(tidyverse)

# Load the analysis panel from data/full_rol.csv (project-relative via here()).
fullData <- fread(input = here("data", "full_rol.csv"))

# Analysis --------------------------------------------------------------------------

## Create Subsets -------------------------------------------------------------------

# Flag rows belonging to the low rule-of-law subset (low_rol = 1).
# For every country that has at least one row with v2x_rule > .8 while
# ml_trans == 1, all of that country's ml == 1 rows are excluded (low_rol = 0).
# Done as a single vectorized update instead of re-running the same update
# once per qualifying row; NA country codes are dropped so they never match.
high_rol_countries <- unique(
  fullData[v2x_rule > .8 & ml_trans == 1 & !is.na(iso3c), iso3c]
)
fullData[, low_rol := 1]
fullData[iso3c %in% high_rol_countries & ml == 1, low_rol := 0]

rm(high_rol_countries)


## PanelMatch -----------------------------------------------------------------------

# Fix the RNG state before matching so the run is reproducible.
set.seed(10)

# PanelMatch on the low rule-of-law subset: treatment `ml`, outcome `v2x_rule`,
# lag window 5, leads 0-5, CBPS weighting on one-year-lagged covariates, ATT.
pm_low <- PanelMatch(
  data = as.data.frame(fullData[fullData$low_rol == 1, ]),
  unit.id = "cid_pm",
  time.id = "year",
  treatment = "ml",
  outcome.var = "v2x_rule",
  lag = 5,
  lead = 0:5,
  qoi = "att",
  refinement.method = "CBPS.weight",
  covs.formula = ~ lag(v2x_rule, 1) +
    lag(bits_ln, 1) +
    lag(market_ln, 1) +
    lag(gdppc_ln, 1) +
    lag(trade_ln, 1) +
    lag(fdi_stock_ln, 1) +
    lag(growth, 1) +
    lag(nyc, 1),
  use.diagonal.variance.matrix = TRUE,
  forbid.treatment.reversal = TRUE
)

# Disaggregated results -------------------------------------------------------------

namesList <- names(pm_low$att)

# For every matched set in pm_low$att, compute over a +/- 5-year window:
#   * weightsDf: one row per matched set (cid_pm, treatYear, one weight
#     column per control unit);
#   * diffDf: per year, the treated unit's change in v2x_rule relative to
#     t-1, the weighted sum of control units' changes, and their difference.
# Work happens inside local() so helpers and loop temporaries do not leak
# into the global environment; only weightsDf and diffDf are kept.
disaggregated <- local({
  window <- 5L

  # Rows of `unit` with year in [treat_year - window, treat_year + window],
  # sorted by year, plus did = v2x_rule - v2x_rule in year treat_year - 1.
  # Stops with a descriptive error if the t-1 row is missing or duplicated.
  unit_did <- function(unit, treat_year) {
    vals <- as.data.frame(
      fullData[fullData$cid_pm == unit &
                 fullData$year >= (treat_year - window) &
                 fullData$year <= (treat_year + window),
               c("cid_pm", "year", "v2x_rule")]
    )
    vals <- vals[order(vals$year), , drop = FALSE]
    baseline <- vals$v2x_rule[vals$year == (treat_year - 1)]
    if (length(baseline) != 1L) {
      stop("Expected exactly one observation for unit ", unit,
           " in year ", treat_year - 1, " (t-1), found ",
           length(baseline), ".", call. = FALSE)
    }
    vals$did <- vals$v2x_rule - baseline
    vals
  }

  weights_list <- vector("list", length(namesList))
  diff_list <- vector("list", length(namesList))

  for (i in seq_along(namesList)) {
    set_name <- namesList[i]
    # Matched-set names are split on "_" or "." into the treated unit id
    # (first piece) and the treatment year (second piece).
    set_info <- strsplit(set_name, "[_\\.]")[[1]]
    treated_id <- set_info[1]
    treat_year <- as.integer(set_info[2])
    set_weights <- attributes(pm_low[["att"]][[set_name]])$weights

    # A matched set without controls has nothing to compare against.
    if (length(set_weights) == 0L) {
      message("Skipping matched set ", set_name, ": no control units.")
      next
    }

    weight_row <- cbind(
      data.frame(cid_pm = treated_id, treatYear = treat_year),
      t(data.frame(set_weights))
    )
    rownames(weight_row) <- NULL
    weights_list[[i]] <- weight_row
    control_ids <- colnames(weight_row)[-(1:2)]

    # Weight each control unit's DiD, then sum across controls per year.
    control_rows <- lapply(control_ids, function(ctrl) {
      vals <- unit_did(ctrl, treat_year)
      vals$weightedDID <- vals$did * weight_row[[ctrl]]
      vals
    })
    control_df <- bind_rows(control_rows)
    control_did <- tapply(control_df$weightedDID, control_df$year, sum)

    treated <- unit_did(treated_id, treat_year)
    # Align the control DiD to the treated unit's years by year value, not
    # by row position, so missing or unsorted years cannot shift the pairing.
    control_aligned <- as.vector(control_did)[
      match(treated$year, as.integer(names(control_did)))
    ]

    diff_list[[i]] <- data.frame(
      cid_pm = treated_id,
      treatYear = treat_year,
      ml_distance = treated$year - treat_year,
      diff = treated$did - control_aligned,
      didTreated = treated$did,
      didControl = control_aligned
    )
  }

  # bind_rows() drops the NULL slots left by skipped matched sets.
  list(weights = bind_rows(weights_list), diff = bind_rows(diff_list))
})

weightsDf <- disaggregated$weights
diffDf <- disaggregated$diff
rm(disaggregated, namesList)

# Attach each treated unit's v2x_rule in its treatment year (left join on
# cid_pm + treatYear == year); unmatched rows keep NA.
diffDf <- merge(
  x = diffDf,
  y = fullData[, c("cid_pm", "year", "v2x_rule")],
  by.x = c("cid_pm", "treatYear"),
  by.y = c("cid_pm", "year"),
  all.x = TRUE
)

# Plot results ----------------------------------------------------------------------

# Mean treated and control DiD by years since enactment (ml_distance).
# The time axis is taken from the tapply group names rather than a
# hard-coded -5:5, so a distance with no observations cannot break or
# shift the pairing between time values and means.
treatedDelta <- tapply(diffDf$didTreated, diffDf$ml_distance, mean)
cntrlDelta   <- tapply(diffDf$didControl, diffDf$ml_distance, mean)

diffs <- data.frame(
  time       = as.integer(names(treatedDelta)),
  treatDelta = as.vector(treatedDelta),
  cntrlDelta = as.vector(cntrlDelta[names(treatedDelta)])
)

# Long format: one row per (time, status); status is "treatDelta" or
# "cntrlDelta", the same values gather() produced.
diffs <- pivot_longer(diffs, cols = -time,
                      names_to = "status", values_to = "diff")

rm(weightsDf, cntrlDelta, treatedDelta)

# Figure 4: per-year mean change in rule of law from t-1, one line/point
# series per `status`, with a dashed zero line and a dashed marker just
# before enactment (x = -0.5). Printed, then written to a 6x4 PDF.
figure4 <- ggplot(diffs, aes(x = time, y = diff, color = status)) +
  geom_hline(yintercept = 0, color = "red", size = .5,
             linetype = "dashed") +
  geom_vline(xintercept = -.5, color = "black", size = .5,
             linetype = "dashed", alpha = .5) +
  geom_line(size = 1, show.legend = FALSE) +
  geom_point(aes(shape = status), size = 4.5) +
  labs(title = "",
       x = "Years Since Model Law Enactment",
       y = "Change in Rule of Law from t-1",
       color = "", shape = "") +
  scale_x_continuous(breaks = -5:5) +
  scale_y_continuous(limits = c(-.03, .03),
                     breaks = seq(-.03, .03, .01)) +
  scale_colour_manual(labels = c("No Model Law", "Model Law"),
                      values = c("grey", "black")) +
  scale_shape_manual(labels = c("No Model Law", "Model Law"),
                     values = c(19, 17)) +
  theme_bw(base_size = 14) +
  theme(panel.grid.major.x = element_blank(),
        panel.grid.minor.x = element_blank(),
        panel.grid.minor.y = element_blank(),
        panel.grid.major.y = element_blank(),
        legend.position = c(0.22, 0.83))

figure4

ggsave(here("output", "figures", "figure-4.pdf"),
       plot = figure4, height = 4, width = 6)

rm(figure4)

